【小ネタ】[Amazon SageMaker] 既存のモデルを使用した増分学習をJupyter Notebookでやってみました
1 はじめに
CX事業本部の平内(SIN)です
Amazon SageMaker(以下、SageMaker)では、既存のモデルを元に学習を開始する増分学習がサポートされており、ここDevelopers.IOでも既に紹介されています。
上記は、コンソールから物体検出の増分学習の要領が、紹介されていますが、これを、単に、Jupyter Notebookでやってみた記録です。
Jupyter Notebookには、物体検出の増分学習のサンプルとして、Amazon SageMaker Object Detection Incremental Trainingがあり、データ形式がRecordIOとなっていますが、今回試したのは、JSON形式のデータセットです。
参考:Now easily perform incremental learning on Amazon SageMaker
2 Jupyter Notebook
Jupyter Notebookの内容は、以下の通りです。
(1) Setup
最初に、ロールの取得、データの入出力用S3バケット(プレフィックス)の定義、object-detectionのDockerイメージの取得を行います。これは、増分学習に限らず共通です。
%%time import sagemaker from sagemaker import get_execution_role # Role role = get_execution_role() print(role) sess = sagemaker.Session() # S3 bucket = 'sagemaker-bucket' prefix = 'my-sample' # Iraning Image from sagemaker.amazon.amazon_estimator import get_image_uri training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest") print (training_image)
(2) Data Preparation
続いて、データセットの準備です。 学習データと検証データのS3バケットを設定していますします。こちらも、通常の学習と同様です。
import os import urllib.request # DataSet train_channel = prefix + '/train' validation_channel = prefix + '/validation' train_annotation_channel = prefix + '/train_annotation' validation_annotation_channel = prefix + '/validation_annotation' s3_train_data = 's3://{}/{}'.format(bucket, train_channel) s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel) s3_train_annotation = 's3://{}/{}'.format(bucket, train_annotation_channel) s3_validation_annotation = 's3://{}/{}'.format(bucket, validation_annotation_channel) train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated', content_type='image/jpeg', s3_data_type='S3Prefix') validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated', content_type='image/jpeg', s3_data_type='S3Prefix') train_annotation = sagemaker.session.s3_input(s3_train_annotation, distribution='FullyReplicated', content_type='image/jpeg', s3_data_type='S3Prefix') validation_annotation = sagemaker.session.s3_input(s3_validation_annotation, distribution='FullyReplicated', content_type='image/jpeg', s3_data_type='S3Prefix')
こちらは、継承する元となるモデルの指定です。S3上に配置された、tar.gz形式のファイルを指定します。
# Model s3_model_data = "s3://sagemaker-bucket/my-sample/output/model.tar.gz" model_data = sagemaker.session.s3_input(s3_model_data, distribution='FullyReplicated', content_type='application/x-sagemaker-model', s3_data_type='S3Prefix')
fit()のパラメータとなる、data_channelsを定義します。学習データ、検証データに併せて、modelという名前で、元となるモデルが指定されます。
# データチャンネル data_channels = {'train': train_data, 'validation': validation_data, 'train_annotation': train_annotation, 'validation_annotation':validation_annotation,'model': model_data} # 出力先 s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)
(3) Traning
増分学習では、num_layers、image_shape、num_classesなどのネットワークを定義するハイパーパラメーターは、既存のモデルのトレーニングに使用されたものと同じである必要があります。
od_model = sagemaker.estimator.Estimator(training_image, role, train_instance_count=1, train_instance_type='ml.p3.2xlarge', train_volume_size = 50, train_max_run = 360000, input_mode = 'File', output_path=s3_output_location, sagemaker_session=sess) od_model.set_hyperparameters(base_network='resnet-50', #use_pretrained_model=1, num_classes=3, ### label count ### mini_batch_size=16, epochs=10, ### epoch count ### learning_rate=0.001, lr_scheduler_step='10', lr_scheduler_factor=0.1, optimizer='sgd', momentum=0.9, weight_decay=0.0005, overlap_threshold=0.5, nms_threshold=0.45, image_shape=512, label_width=600, num_training_samples=1808) ### data count ###
学習を開始します。
od_model.fit(inputs=data_channels, logs=True)
ログを見ると、epoch=0で既にscore=0.987となっており、既存のモデルを利用して学習を開始していることが分かります。
#quality_metric: host=algo-1, epoch=0, batch=113 train cross_entropy =(0.19252898495207674) #quality_metric: host=algo-1, epoch=0, batch=113 train smooth_l1 =(0.0477017551396642) #quality_metric: host=algo-1, epoch=0, validation mAP =(0.9876995446829199)
3 最後に
今回は、既存のモデルを使用して学習を追加出来るように、Jupyter Notebookを準備してみました。 正直な所、Epochを何回に指定して学習すれば良いのか、良く分かってないので、無駄に回しすぎたりしないように、増分学習で少しすづ確認しながら進めています。
すべてのコードは、下記に起きました。